import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_from_disk, Dataset
import torch
from peft import PeftConfig, PeftModel
import json
import re
from tqdm import tqdm
import random
import numpy as np
import argparse
from prompt import expert_prompt_text
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(model_path):
    """Load a PEFT adapter checkpoint together with its base model and tokenizer.

    Args:
        model_path: Location of the adapter checkpoint. Its PEFT config
            names the base model through ``base_model_name_or_path``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the base causal LM wrapped
        with the adapter and switched to evaluation mode; ``tokenizer`` is
        loaded from the base model name.
    """
    peft_config = PeftConfig.from_pretrained(model_path)
    base_name = peft_config.base_model_name_or_path

    # The base model is loaded with device_map="auto" before the adapter is attached.
    base_model = AutoModelForCausalLM.from_pretrained(base_name, device_map="auto")
    adapted_model = PeftModel.from_pretrained(base_model, model_path)
    tokenizer = AutoTokenizer.from_pretrained(base_name)

    adapted_model.eval()
    return adapted_model, tokenizer

def predict_next_part_with_llama(model, tokenizer, input_text, max_new_tokens=10):
    """Generate a continuation of ``input_text`` and return the decoded text.

    Args:
        model: Object exposing ``generate`` (and optionally a ``device``
            attribute), e.g. the model returned by ``load_model``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        input_text: Prompt string to continue.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The first generated sequence decoded with special tokens skipped.
        NOTE(review): for causal LMs this normally contains the prompt
        followed by the continuation -- confirm for the model in use.
    """
    inputs = tokenizer.encode(input_text, return_tensors='pt')

    # The tokenizer builds tensors on the CPU; move them to the device the
    # model reports so generation does not mix CPU inputs with GPU weights.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    # do_sample=False keeps decoding deterministic; one sequence is returned.
    outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def get_dataset_json(data_path):
    """Read a JSON file and return its parsed content.

    Args:
        data_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_dataset(args):
    """Build the evaluation dataset selected by the command-line arguments.

    Args:
        args: Namespace providing ``dataset`` (only the part after the last
            ``/`` is used), ``split`` (``'test'`` or ``'train'``) and ``seed``.

    Returns:
        For ``StackMIA``: a list of ``{'text': ..., 'label': ...}`` dicts built
        from the ``snippet`` and ``label`` columns of the chosen partition.
        For ``wikiMIA``: the parsed content of the matching JSON file.

    Raises:
        ValueError: If the dataset name or the split is not supported.
    """
    data_path_config = {
        "StackMIA": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/StackMIAsub",
        "wikiMIA": "/data/home/zhanghx/code/DataContaminate/benchmarks/fine_tuning/",
    }

    dataset_name = args.dataset.split('/')[-1]
    if dataset_name not in data_path_config:
        raise ValueError("Unsupported dataset")
    # Fail early with a clear message instead of hitting an unbound local
    # variable further down when the split is misspelled.
    if args.split not in ('test', 'train'):
        raise ValueError(
            f"Unsupported split: {args.split!r} (expected 'test' or 'train')"
        )

    if dataset_name == 'StackMIA':
        ds = load_from_disk(data_path_config["StackMIA"])
        # TODO: seed must be the same as the seed in new_sft.py
        ds = ds.shuffle(args.seed)
        # TODO: seed must be the same as the seed in new_sft.py
        train_test_split = ds.train_test_split(test_size=0.3, seed=args.seed)
        # NOTE(review): split 'test' selects the 70% 'train' partition and
        # split 'train' selects the 30% 'test' partition. This mapping is kept
        # as found; confirm it matches how new_sft.py consumes the data.
        partition = 'train' if args.split == 'test' else 'test'
        data_eval = train_test_split[partition]
        dataset = [
            {'text': row['snippet'], 'label': row['label']}
            for row in data_eval
        ]
    else:  # dataset_name == 'wikiMIA'
        file_name = "test_data.json" if args.split == 'test' else "train_data.json"
        dataset = get_dataset_json(data_path_config["wikiMIA"] + file_name)

    print("dataset numbers: ", len(dataset))
    return dataset

def generate1(dataset):
    """Wrap every sample in the plain Yes/No membership prompt.

    Args:
        dataset: Sequence of mappings with ``'text'`` and ``'label'`` entries.

    Returns:
        A ``(texts, labels)`` tuple of two lists in dataset order: the
        formatted prompts and the corresponding labels.
    """
    prompts = []
    targets = []
    for sample in dataset:
        input_text = sample['text']
        # Newlines and the 8-space indentation are spelled out explicitly so
        # the prompt layout does not depend on source-code indentation.
        prompt = (
            'Below is an input may be from pre-training corpus. if the input is seen in the pre-training step, the answer is "Yes", otherwise, it is "No". Please provide an answer. \n'
            '\n'
            '        ### Input:\n'
            f'        {input_text}\n'
            '\n'
            '        ### answer:\n'
            '        '
        )
        targets.append(sample['label'])
        prompts.append(prompt)
    return prompts, targets
        
def generate(dataset):
    """Wrap every sample in the prompt that starts with ``expert_prompt_text``.

    Args:
        dataset: Sequence of mappings with ``'text'`` and ``'label'`` entries.

    Returns:
        A ``(texts, labels)`` tuple of two lists in dataset order: the
        formatted prompts and the corresponding labels.
    """
    prompts = []
    targets = []
    for sample in dataset:
        input_text = sample['text']
        # Newlines and the 8-space indentation are spelled out explicitly so
        # the prompt layout does not depend on source-code indentation.
        prompt = (
            f'{expert_prompt_text}\n'
            '\n'
            '        ### Input:\n'
            f'        {input_text}\n'
            '\n'
            '        ### answer:\n'
            '        '
        )
        targets.append(sample['label'])
        prompts.append(prompt)
    return prompts, targets

def parallel_predict(input_texts, model, tokenizer, max_new_tokens=10):
    """Run generation for every prompt, one after another, with a progress bar.

    Despite the name, the prompts are processed sequentially in a single loop.

    Args:
        input_texts: List of prompt strings.
        model: Model handed to ``predict_next_part_with_llama``.
        tokenizer: Tokenizer handed to ``predict_next_part_with_llama``.
        max_new_tokens: Upper bound on newly generated tokens per prompt.

    Returns:
        List of generated texts in the same order as ``input_texts``. Each
        text is also printed as soon as it is produced.
    """
    generations = []
    for prompt in tqdm(input_texts):
        generated = predict_next_part_with_llama(model, tokenizer, prompt, max_new_tokens)
        generations.append(generated)
        print(generated)
    return generations

def extract_answer(generated_text):
    """Pull the Yes/No answer out of each generated text.

    An answer is recognised when ``### answer:`` is followed by a newline,
    optional whitespace and then ``Yes`` or ``No``; the first such occurrence
    in a text is used.

    Args:
        generated_text: Iterable of generated strings.

    Returns:
        List with ``'Yes'``, ``'No'`` or ``None`` (no answer found) for each
        input text, in order.
    """
    answer_pattern = re.compile(r'### answer:\n\s*(Yes|No)')
    answers = []
    for text in generated_text:
        found = answer_pattern.search(text)
        answers.append(found.group(1).strip() if found else None)
    return answers

def evaluate(results, labels):
    """Score extracted answers against the ground-truth labels.

    An answer of ``'No'`` counts as prediction 0; any other non-``None``
    answer counts as prediction 1. ``None`` entries are reported on stdout,
    counted separately and excluded from every metric.

    Args:
        results: List of extracted answers (``'Yes'``, ``'No'`` or ``None``).
        labels: Ground-truth labels, indexed in parallel with ``results``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, count)`` where ``count`` is
        the number of ``None`` answers. A metric whose denominator is zero is
        returned as 0.
    """
    true_pos = true_neg = false_pos = false_neg = 0
    unanswered = 0

    for idx, answer in enumerate(results):
        if answer is None:
            print("Error: no answer found in the generated text.")
            unanswered += 1
            continue

        predicted = 0 if answer == 'No' else 1
        is_correct = predicted == labels[idx]
        if predicted == 1:
            if is_correct:
                true_pos += 1
            else:
                false_pos += 1
        elif is_correct:
            true_neg += 1
        else:
            false_neg += 1

    answered = len(results) - unanswered
    hits = true_pos + true_neg
    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg

    accuracy = hits / answered if answered != 0 else 0
    precision = true_pos / predicted_pos if predicted_pos != 0 else 0
    recall = true_pos / actual_pos if actual_pos != 0 else 0
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1, unanswered


if __name__ == '__main__':
    # Command-line entry point: load the adapter model, build the prompts for
    # the selected dataset split, generate answers and print the metrics.
    parser = argparse.ArgumentParser('evaluate')
    parser.add_argument('--dataset', type=str, default="StackMIA", choices=["StackMIA", "wikiMIA"])
    parser.add_argument('--seed', type=int, default=42)
    # Only these two values are handled by get_dataset; reject anything else
    # at parse time instead of failing later.
    parser.add_argument('--split', type=str, default='test', choices=['test', 'train'])
    parser.add_argument('--model_path', type=str, default='/data/home/zhanghx/code/DataContaminate/ckpts/model/llama-7b/seed_1/wikiMIA-20240801052942')

    args = parser.parse_args()
    model, tokenizer = load_model(args.model_path)
    dataset = get_dataset(args)
    texts, labels = generate(dataset)
    results = parallel_predict(texts, model, tokenizer, max_new_tokens=10)
    results = extract_answer(results)

    accuracy, precision, recall, f1, count = evaluate(results, labels)
    print("accuracy: ", accuracy)
    print("precision: ", precision)
    print("recall: ", recall)
    # F1 was computed by evaluate() but previously never reported.
    print("f1: ", f1)
    # Number of generations from which no Yes/No answer could be extracted.
    print("count: ", count)
    
